# !/usr/bin/env python3
# Copyright (c) 2022 Institute for Quantum Computing, Baidu Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""
Defines the ``Protein`` class, which holds all the information of the protein lattice model.
"""

from typing import List, Tuple, Dict, Optional
import logging
import networkx as nx
from openfermion import QubitOperator
from paddle_quantum import Hamiltonian
from paddle_quantum.biocomputing.data_loader import load_energy_matrix_file
from paddle_quantum.biocomputing.operators import edge_direction_indicator, contact_indicator, backwalk_indicator

# Names exported by ``from <module> import *``.
__all__ = ["Protein"]

# One-letter amino acid symbols accepted in a sequence.
# The position of a symbol in this list is used as its row/column index when
# looking up the energy matrix, so the order matters.
# NOTE(review): presumably this order matches the symbol order of the energy
# matrix file -- confirm against ``load_energy_matrix_file``.
AA_SYMBOLS = [
    "C", "M", "F", "I", "L", "V", "W", "Y", "A", "G", "T", "S", "N", "Q",
    "D", "E", "H", "R", "K", "P"
]


def _check_valid_aa_seq(aa_seq: str) -> bool:
    r"""check if a given amino acid sequence is valid.
    
    Note:
        This is an internal function, users don't intended to use it directly.

    Args:
        aa_seq: Amino acides in the protein.
    
    Return:
        If input is a valid amino acide sequence, will return True, otherwise, will return False.
    """
    for a in aa_seq:
        if a not in AA_SYMBOLS:
            raise ValueError(f"Input amino acid symbol must in {AA_SYMBOLS}, get {a:s}.")
    num_aa = len(aa_seq)
    assert num_aa >= 2, "The smallest allowed number of amino acids in protein is 2."
    return True


def _generate_valid_contact_pairs(num_aa: int) -> List[Tuple[int]]:
    r"""A function that generate the potential contact node pairs. Then number of
    nodes in a protein is given by ``num_aa`` .

    Note:
        This is an internal function, users don't intended to use it directly.

    Args:
        num_aa: Number of amino acides in the protein.
    
    Return:
        Possible contact pairs for the given protein.
    """
    pairs = []
    for i in range(num_aa):
        i1 = i + 5
        if i1 <= num_aa:
            pairs.extend((i, j) for j in range(i1, num_aa) if (j-i) % 2 == 1)
    return pairs


class Protein(nx.Graph):
    r"""
    Protein object will build a protein from a given amino acid sequence.
    The 3D structure of the protein is built on a diamond lattice, each bond is
    mapped to an edge in the lattice. We can get the spatial direction of the
    bonds and the distance between two amino acids from it.

    The graph has one node per amino acid (attribute ``symbol``) and one edge per
    consecutive pair ``(i, i+1)`` (attributes ``sign`` and ``indicator``, both
    produced by ``edge_direction_indicator``).

    Args:
        aa_seqs: an ordered string in which each symbol represents an amino acid (aa).
        fixed_bond_directions: a mapping from an edge ``(i, i+1)`` to the integer
            direction of that prefixed bond. Edges that are missing from the mapping
            (or whose value is not an ``int``) are left free and each consumes two
            configuration qubits. Default is None, which means no bond is fixed a priori.
        contact_pairs: a list of potentially contacting amino acid pairs. Default
            is None and will use all the valid pairs generated by ``_generate_valid_contact_pairs`` .

    Attributes:
        contact_energies: mapping from each contact pair ``(p, q)`` to the energy
            read from the energy matrix for the two amino acids.
    """
    def __init__(
        self,
        aa_seqs: str,
        fixed_bond_directions: Optional[Dict[Tuple, int]] = None,
        contact_pairs: Optional[List[Tuple[int]]] = None,
    ) -> None:
        super().__init__()
        _check_valid_aa_seq(aa_seqs)

        num_aa = len(aa_seqs)
        if fixed_bond_directions is None:
            fixed_bond_directions = {}

        # Each bond whose direction is not prefixed gets its own block of two
        # configuration qubits; ``config_qblock_idx`` counts the blocks used so far.
        config_qblock_idx = 0
        for e_a in range(num_aa - 1):
            e_b = e_a + 1
            direction = fixed_bond_directions.get((e_a, e_b))
            if isinstance(direction, int):
                sign, indicator = edge_direction_indicator((e_a, e_b), direction=direction)
            else:
                sign, indicator = edge_direction_indicator((e_a, e_b), [2*config_qblock_idx, 2*config_qblock_idx+1])
                config_qblock_idx += 1
            self.add_edge(e_a, e_b, sign=sign, indicator=indicator)
        # Label every node explicitly instead of relying on the loop variable
        # leaking out of the loop to reach the last node.
        for node, symbol in enumerate(aa_seqs):
            self.nodes[node]["symbol"] = symbol
        self._num_config_qubits = 2*config_qblock_idx

        if contact_pairs is None:
            contact_pairs = _generate_valid_contact_pairs(num_aa)
        self._num_contact_qubits = len(contact_pairs)

        logging.info("\n#######################################\nProtein (lattice model)\n#######################################")
        logging.info(f"Symbols: {'-'.join(aa_seqs):s}")
        logging.info("Lattice: diamond lattice")
        energies, _ = load_energy_matrix_file()
        self._energy_matrix = energies
        logging.info("Assumed contact pairs [energy]:")
        contact_energies = {}
        for p, q in contact_pairs:
            en = self._pair_energy(aa_seqs[p], aa_seqs[q])
            contact_energies[(p, q)] = en
            logging.info(f"\t({p:d},{q:d}) [{en:.5f}]")
        self.contact_energies = contact_energies

    def _pair_energy(self, symbol_a: str, symbol_b: str):
        r"""Look up the energy matrix entry for two amino acid symbols.

        The two indices are sorted before the lookup so that every lookup in this
        class addresses the matrix the same way, whatever the argument order.
        NOTE(review): this presumably targets an upper-triangular matrix --
        confirm against ``load_energy_matrix_file``.

        Args:
            symbol_a: symbol of the first amino acid, must be in ``AA_SYMBOLS``.
            symbol_b: symbol of the second amino acid, must be in ``AA_SYMBOLS``.

        Returns:
            The matrix entry ``[i, j]`` with ``i <= j``.
        """
        i, j = sorted((AA_SYMBOLS.index(symbol_a), AA_SYMBOLS.index(symbol_b)))
        return self._energy_matrix[i, j]

    @property
    def num_config_qubits(self) -> int:
        r"""Number of configuration qubits: two per bond whose direction is not prefixed."""
        return self._num_config_qubits

    @property
    def num_contact_qubits(self) -> int:
        r"""Number of contact qubits: one per entry of the contact pair list."""
        return self._num_contact_qubits

    @property
    def num_qubits(self) -> int:
        r"""Total number of qubits, i.e. configuration qubits plus contact qubits."""
        return self.num_config_qubits + self.num_contact_qubits

    def distance_operator(self, p: int, q: int) -> QubitOperator:
        r"""Distance between p-th and q-th nodes in the protein graph.

        The signed direction indicators of all the edges on the chain between the
        two nodes are accumulated per lattice axis (4 axes), and the sum of the
        squares of the 4 accumulated components is returned.

        Args:
            p: index of node p
            q: index of node q

        Returns:
            The operator.

        Raises:
            AssertionError: If ``p`` equals ``q``.
        """
        if p == q:
            # Raised explicitly so the check survives ``python -O``; the
            # exception type is the one the former ``assert`` produced.
            raise AssertionError("The node indices shouldn't be equal.")
        start, stop = sorted((p, q))  # walk the chain from the smaller index

        delta_x = [0]*4
        for node in range(start, stop):
            attrs: Dict = self.edges[(node, node+1)]
            sign = attrs["sign"]
            indicator = attrs["indicator"]
            for a in range(4):
                delta_x[a] += sign*indicator[a]
        return sum(t**2 for t in delta_x)

    def get_protein_hamiltonian(self, lambda0: float = 0.0, lambda1: float = 0.0, energy_multiplier: float = 0.1) -> Hamiltonian:
        r"""
        The Hamiltonian used in VQE algorithm to solve protein folding problem (will take into account the first and
        second neighbor interactions).

        .. math::

            \begin{array}{rcl}
            \hat{H} & = & \lambda_0\hat{H}_{backward}
                + \sum_{(i,j)} q_{ij}\Big[ m\,\epsilon_{ij} + \lambda_1\big(d(i,j)-1\big) \\
             & & + \sum_{r\in n(i)} \big(m\,\epsilon_{rj} + \lambda_1(2-d(r,j))^2\big)
                + \sum_{s\in n(j)} \big(m\,\epsilon_{is} + \lambda_1(2-d(i,s))^2\big) \Big]
            \end{array}

        Here the sum runs over the contact pairs :math:`(i,j)` (by default
        :math:`j\in\{k \mid k\ge i+5,\ (k-i) \bmod 2 = 1\}`), :math:`q_{ij}` is the
        contact indicator of the pair, :math:`\epsilon` is the energy matrix entry of
        two amino acids, :math:`m` is ``energy_multiplier``, :math:`d` is the operator
        returned by ``distance_operator``, :math:`n(i)` are the graph neighbors of node
        :math:`i`, and :math:`\hat{H}_{backward}` is the sum of ``backwalk_indicator``
        over every two consecutive edges.

        Args:
            lambda0: the penalty factor for the backwalk constraint.
            lambda1: the penalty factor for the contact constraints.
            energy_multiplier: the factor applied to every energy matrix entry.

        Returns:
            Hamiltonian.
        """
        h = 0*QubitOperator("")
        # Contact qubits follow the configuration qubits and are numbered by the
        # sorted order of the contact pairs. The lookup table is built once
        # instead of searching a list for every pair.
        contact_qubit = {
            pair: self.num_config_qubits + offset
            for offset, pair in enumerate(sorted(self.contact_energies))
        }
        for (p, q), e_pq in self.contact_energies.items():
            I_pq = contact_indicator(contact_qubit[(p, q)])
            h += I_pq * (energy_multiplier*e_pq + lambda1*(self.distance_operator(p, q) - 1))
            symbol_p = self.nodes[p]["symbol"]
            symbol_q = self.nodes[q]["symbol"]
            # restrict the neighbor distance
            # use squared constraint to restrict the positivity.
            for r in self.neighbors(p):
                e2_rq = self._pair_energy(self.nodes[r]["symbol"], symbol_q)
                h += I_pq * (energy_multiplier*e2_rq + lambda1*(2 - self.distance_operator(r, q))**2)
            for s in self.neighbors(q):
                e2_ps = self._pair_energy(symbol_p, self.nodes[s]["symbol"])
                h += I_pq * (energy_multiplier*e2_ps + lambda1*(2 - self.distance_operator(p, s))**2)

        # forbid backwalk
        edges = sorted(self.edges)
        for edge0, edge1 in zip(edges, edges[1:]):
            e0_attr = self.edges[edge0]["indicator"]
            e1_attr = self.edges[edge1]["indicator"]
            h += lambda0*backwalk_indicator(e0_attr, e1_attr)

        return Hamiltonian.from_qubit_operator(h)
